{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import datetime\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据读取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 23/374 [00:00<00:13, 26.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "wrong file: CZCE.JR003.csv\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 374/374 [00:10<00:00, 35.45it/s]\n"
     ]
    }
   ],
   "source": [
    "base_path = 'data'\n",
    "df_zhuli = pd.read_csv(os.path.join(base_path, 'zhuli.csv'))\n",
    "\n",
    "data_path = os.path.join(base_path, '1minute')\n",
    "df_min = []\n",
    "for path in tqdm(os.listdir(data_path)):\n",
    "    try:\n",
    "        df = pd.read_csv(os.path.join(data_path, path))\n",
    "        df.columns = ['datetime', 'open', 'high', 'low', 'close', 'volume', 'open_oi', 'close_oi']\n",
    "        df['ts_code'] = path[:-4]\n",
    "        df_min.append(df)\n",
    "    except:\n",
    "        print('wrong file: %s' % path)\n",
    "#     break\n",
    "df_min = pd.concat(df_min)\n",
    "df_min = df_min.reset_index(drop=True)\n",
    "\n",
    "df_min['datetime'] = pd.to_datetime(df_min['datetime'], infer_datetime_format=True)\n",
    "index_df = pd.read_csv(os.path.join(base_path, 'index.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 374/374 [00:01<00:00, 202.26it/s]\n"
     ]
    }
   ],
   "source": [
    "df_day = []\n",
    "data_path = os.path.join(base_path, 'day')\n",
    "for path in tqdm(os.listdir(data_path)):\n",
    "    df = pd.read_csv(os.path.join(data_path, path))\n",
    "    df['ts_code'] = path[:-4]\n",
    "    df_day.append(df)\n",
    "df_day = pd.concat(df_day)\n",
    "\n",
    "# 有些合约一出来就是主力合约，没有pre_close，用open替换\n",
    "idx = df_day['pre_close'].isna()\n",
    "df_day.loc[idx, 'pre_close'] = df_day.loc[idx, 'open']\n",
    "df_day = df_day[df_day['vol']>=500]\n",
    "df_day = df_day.reset_index(drop=True)\n",
    "df_day = df_day.rename(columns={'high':'today_high', 'low':'today_low'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 所有的交易日\n",
    "day_list = sorted(df_day['trade_date'].unique())\n",
    "day2idx = dict(zip(day_list, range(len(day_list))))\n",
    "idx2day = dict(zip(range(len(day_list)), day_list))\n",
    "day_set = set(day_list)\n",
    "\n",
    "# 获取所属交易日， 在21:00至00:00属于下个交易日\n",
    "def get_day(x):\n",
    "    # 如果当天大于20点，算新的一天\n",
    "    day = x.year * 10000 + x.month * 100 + x.day\n",
    "    if x.hour > 20:\n",
    "        idx = day2idx[day]\n",
    "        day = idx2day[idx + 1]\n",
    "\n",
    "    # 在凌晨的时候，周六凌晨也会交易\n",
    "    if x.hour < 8:\n",
    "        if day not in day_set:\n",
    "            # 说明是在非工作日凌晨交易,并且属于下一个交易日\n",
    "            x = x - pd.Timedelta(days=1)\n",
    "            day = x.year * 10000 + x.month * 100 + x.day\n",
    "            idx = day2idx[day]\n",
    "            day = idx2day[idx + 1]\n",
    "\n",
    "    return day\n",
    "\n",
    "df_min['trade_date'] = df_min['datetime'].apply(get_day)\n",
    "df_min = df_min.merge(df_day[['ts_code', 'trade_date', 'pre_close', 'today_high', 'today_low']], on=['ts_code', 'trade_date'], how='left')\n",
    "# 剔除无昨日开盘价数据，这类数据几乎整天无交易\n",
    "df_min = df_min[~df_min['pre_close'].isna()].reset_index(drop=True)\n",
    "df_min['mean'] = (df_min['high'] + df_min['low'] + df_min['close'] + df_min['open']) / 4\n",
    "df_min['rate'] = (df_min['mean'] - df_min['pre_close']) / df_min['pre_close']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# index feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "index_df = index_df.sort_values(['datetime'], ascending=[True])\n",
    "index_df.columns = ['datetime', 'index_rate']\n",
    "for day in [1, 2, 3, 4]:\n",
    "    index_df ['index_rate_shift'+str(day)] = (index_df['index_rate'].shift(day) - index_df['index_rate']) / index_df['index_rate']\n",
    "index_df['datetime'] = pd.to_datetime(index_df['datetime'], infer_datetime_format=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# normal feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 338/338 [00:27<00:00, 12.16it/s]\n"
     ]
    }
   ],
   "source": [
    "df_list = []\n",
    "# num = 0\n",
    "for i, g in tqdm(df_min.groupby(['ts_code'])):\n",
    "#     num += 1\n",
    "#     if num % 200 == 199:\n",
    "#         print(num)\n",
    "    g = g.sort_values(['datetime'], ascending=[True]).reset_index(drop=True)\n",
    "\n",
    "    col = ['ts_code', 'datetime']\n",
    "    for day in [5, 10, 15, 30, 60]:\n",
    "        g['max_' + str(day)] = g['high'].rolling(day).max()\n",
    "        g['min_' + str(day)] = g['low'].rolling(day).min()\n",
    "        g['mean_' + str(day)] = g['mean'].rolling(day).mean()\n",
    "        g['volume_' + str(day)] = g['volume'].rolling(day).sum()\n",
    "\n",
    "        col.append('max_' + str(day))\n",
    "        col.append('min_' + str(day))\n",
    "        col.append('mean_' + str(day))\n",
    "        col.append('volume_' + str(day))\n",
    "\n",
    "    # 再处理\n",
    "    for day in [5, 10, 15, 30, 60]:\n",
    "        g['max_' + str(day)] = (g['mean'] - g['max_' + str(day)]) / g['max_' + str(day)]\n",
    "        g['min_' + str(day)] = (g['mean'] - g['min_' + str(day)]) / g['min_' + str(day)]\n",
    "        g['mean_' + str(day)] = (g['mean'] - g['mean_' + str(day)]) / g['mean_' + str(day)]\n",
    "        g['volume_' + str(day)] = (g['volume'] - g['volume_' + str(day)]) / g['volume_' + str(day)]\n",
    "\n",
    "\n",
    "    for day in [1, 2, 3]:\n",
    "        g['high_shift' + str(day)] = (g['high'].shift(day) - g['mean']) / g['mean']\n",
    "        g['low_shift' + str(day)] = (g['low'].shift(day) - g['mean']) / g['mean']\n",
    "        g['mean_shift' + str(day)] = (g['mean'].shift(day) - g['mean']) / g['mean']\n",
    "\n",
    "        g['high_shift2' + str(day)] = (g['high'].shift(day) - g['pre_close']) / g['pre_close']\n",
    "        g['low_shift2' + str(day)] = (g['low'].shift(day) - g['pre_close']) / g['pre_close']\n",
    "        g['mean_shift2' + str(day)] = (g['mean'].shift(day) - g['pre_close']) / g['pre_close']\n",
    "\n",
    "        col.append('high_shift' + str(day))\n",
    "        col.append('low_shift' + str(day))\n",
    "        col.append('mean_shift' + str(day))\n",
    "\n",
    "        col.append('high_shift2' + str(day))\n",
    "        col.append('low_shift2' + str(day))\n",
    "        col.append('mean_shift2' + str(day))\n",
    "\n",
    "    df_list.append(g[col])\n",
    "\n",
    "df_list = pd.concat(df_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_min = df_min.merge(df_list, on=['ts_code', 'datetime'], how='left')\n",
    "\n",
    "# def MakeFeature(df):\n",
    "df_min['high_t'] = (df_min['high'] - df_min['pre_close']) / df_min['pre_close']\n",
    "df_min['low_t'] = (df_min['low']-df_min['pre_close']) / df_min['pre_close']\n",
    "df_min['open_t'] = (df_min['open']-df_min['pre_close']) / df_min['pre_close']\n",
    "\n",
    "col.append('high_t')\n",
    "col.append('low_t')\n",
    "col.append('open_t')\n",
    "\n",
    "df_min = df_min.merge(index_df, on=['datetime'], how='left')\n",
    "\n",
    "col += ['index_rate', 'index_rate_shift1', 'index_rate_shift2', 'index_rate_shift3', 'index_rate_shift4']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_time(x):\n",
    "    # 夜盘21:00-2:30\n",
    "    h = x.hour\n",
    "    if h > 20 and h < 24:\n",
    "        t = (h-21)*60 + x.minute\n",
    "    elif h >= 0 and h < 3:\n",
    "        t = 180 + h*60 + x.minute\n",
    "    elif h > 8:\n",
    "        t = 330 + (h-9)*60 + x.minute\n",
    "    return t\n",
    "df_min['time'] = df_min['datetime'].apply(get_time)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Make Label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_label(df_min):\n",
    "    # 20 分钟后的涨幅，且日内交易\n",
    "    delta_t = 20\n",
    "    df_label = []\n",
    "    tmp_col = ['ts_code', 'trade_date', 'datetime', 'mean', 'low', 'high']\n",
    "    for i, g in tqdm(df_min[tmp_col].groupby(['ts_code', 'trade_date'])):\n",
    "        g = g.sort_values('datetime', ascending=True).reset_index(drop=True)\n",
    "        g['mean_shift1'] = g['mean'].shift(-1)\n",
    "        g['mean_latter'] = g['mean'].shift(-delta_t)\n",
    "        g['datetime_open'] = g['datetime'].shift(-1) # 开仓时间\n",
    "        g['datetime_close'] = g['datetime'].shift(-delta_t) # 平仓时间\n",
    "        g['lowest'] = g.rolling(delta_t)['low'].min().shift(-delta_t)\n",
    "        g['highest'] = g.rolling(delta_t)['high'].max().shift(-delta_t)\n",
    "        df_label.append(g)\n",
    "    df_label = pd.concat(df_label)\n",
    "    return df_label\n",
    "\n",
    "if os.path.exists(os.path.join(base_path, 'label.csv')):\n",
    "    df_label = pd.read_csv(os.path.join(base_path, 'label.csv'))\n",
    "else:\n",
    "    df_label = get_label(df_min)\n",
    "    df_label.to_csv(os.path.join(base_path, 'label.csv'), index=None)\n",
    "df_label['datetime'] = pd.to_datetime(df_label['datetime'], infer_datetime_format=True)\n",
    "# df_label['label'] = df_label['mean_latter'] >= df_label['mean']\n",
    "df_label['label'] = df_label['mean_latter'] >= df_label['mean_shift1']\n",
    "df_label['return'] = (df_label['mean_latter'] - df_label['mean_shift1']) / df_label['mean_shift1']\n",
    "df_min = df_min.merge(df_label[['datetime', 'ts_code', 'label', 'return', 'datetime_open', 'datetime_close']], on=['datetime', 'ts_code'], how='left')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_min.dropna(inplace=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 假设 分钟行情的最低价、最高价、都等于当日最高价视为涨停或跌停\n",
    "df_min['limit'] = 0\n",
    "idx = (df_min['low']==df_min['high'])&((df_min['low']==df_min['today_low'])|(df_min['high']==df_min['today_high']))\n",
    "df_min.loc[idx, 'limit'] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "46\n"
     ]
    }
   ],
   "source": [
    "feature_col = col[2:]\n",
    "print(len(feature_col))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_col = 'label'\n",
    "\n",
    "trn_date_min = 20190101\n",
    "trn_date_max = 20190901\n",
    "\n",
    "test_date_min = 20190902\n",
    "test_date_max = 20200501\n",
    "\n",
    "# 训练集中需要提出涨幅为0（可能为涨停或跌停）\n",
    "\n",
    "trn_idx = (df_min['trade_date'] >= trn_date_min) & (df_min['trade_date'] <= trn_date_max) & (df_min['return']!=0)\n",
    "test_idx = (df_min['trade_date'] >= test_date_min) & (df_min['trade_date'] <= test_date_max) & (df_min['limit']==0)\n",
    "\n",
    "trn = df_min[trn_idx][feature_col].values\n",
    "trn_label = df_min[trn_idx][label_col].values\n",
    "\n",
    "test = df_min[test_idx][feature_col].values\n",
    "test_label = df_min[test_idx][label_col].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0:1093408, 1:1090112\n",
      "rate 0:0.501, rate 1:0.499\n",
      "len trn:2183520, len test:2065200\n"
     ]
    }
   ],
   "source": [
    "num_1 = np.sum(trn_label)\n",
    "num_0 = len(trn_label)-np.sum(trn_label)\n",
    "print('0:%d, 1:%d' %(num_0, num_1))\n",
    "print('rate 0:%.3f, rate 1:%.3f' % (num_0/len(trn_label), num_1/len(trn_label)))\n",
    "print('len trn:%d, len test:%d' % (len(trn), len(test)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "60.926990270614624\n"
     ]
    }
   ],
   "source": [
    "# 模型训练及评价\n",
    "import lightgbm as lgb\n",
    "import time\n",
    "t1 = time.time()\n",
    "from sklearn import metrics\n",
    "param = {'num_leaves': 31,\n",
    "         'min_data_in_leaf': 20,\n",
    "         'objective': 'binary',\n",
    "         'learning_rate': 0.06,\n",
    "         \"boosting\": \"gbdt\",\n",
    "         \"metric\": 'None',\n",
    "         \"verbosity\": -1}\n",
    "trn_data = lgb.Dataset(trn, trn_label)\n",
    "num_round =888\n",
    "\n",
    "clf = lgb.train(param, trn_data, 400, verbose_eval=300)\n",
    "t2 = time.time()\n",
    "print(t2-t1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>column</th>\n",
       "      <th>importance</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>index_rate</td>\n",
       "      <td>2174</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45</th>\n",
       "      <td>index_rate_shift4</td>\n",
       "      <td>740</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>max_60</td>\n",
       "      <td>670</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>min_60</td>\n",
       "      <td>663</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>mean_60</td>\n",
       "      <td>596</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>42</th>\n",
       "      <td>index_rate_shift1</td>\n",
       "      <td>499</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>43</th>\n",
       "      <td>index_rate_shift2</td>\n",
       "      <td>483</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>44</th>\n",
       "      <td>index_rate_shift3</td>\n",
       "      <td>472</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>max_30</td>\n",
       "      <td>472</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>min_30</td>\n",
       "      <td>365</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>high_t</td>\n",
       "      <td>333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>mean_30</td>\n",
       "      <td>315</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>low_t</td>\n",
       "      <td>310</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>min_15</td>\n",
       "      <td>309</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>max_15</td>\n",
       "      <td>293</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>mean_10</td>\n",
       "      <td>196</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>max_10</td>\n",
       "      <td>172</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>high_shift23</td>\n",
       "      <td>170</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>min_10</td>\n",
       "      <td>159</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>volume_60</td>\n",
       "      <td>157</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>low_shift23</td>\n",
       "      <td>155</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>mean_15</td>\n",
       "      <td>153</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>mean_shift1</td>\n",
       "      <td>141</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>high_shift21</td>\n",
       "      <td>140</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>open_t</td>\n",
       "      <td>137</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>low_shift21</td>\n",
       "      <td>135</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>min_5</td>\n",
       "      <td>129</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>low_shift1</td>\n",
       "      <td>126</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>low_shift22</td>\n",
       "      <td>118</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>max_5</td>\n",
       "      <td>112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>volume_5</td>\n",
       "      <td>110</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>high_shift22</td>\n",
       "      <td>107</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>high_shift1</td>\n",
       "      <td>97</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>volume_30</td>\n",
       "      <td>92</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>mean_shift23</td>\n",
       "      <td>90</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>volume_15</td>\n",
       "      <td>88</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>low_shift3</td>\n",
       "      <td>64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>volume_10</td>\n",
       "      <td>64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>mean_shift22</td>\n",
       "      <td>60</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>mean_shift21</td>\n",
       "      <td>59</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>high_shift2</td>\n",
       "      <td>56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>mean_shift3</td>\n",
       "      <td>50</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>high_shift3</td>\n",
       "      <td>48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>mean_5</td>\n",
       "      <td>48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>low_shift2</td>\n",
       "      <td>39</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>mean_shift2</td>\n",
       "      <td>34</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               column  importance\n",
       "41         index_rate        2174\n",
       "45  index_rate_shift4         740\n",
       "16             max_60         670\n",
       "17             min_60         663\n",
       "18            mean_60         596\n",
       "42  index_rate_shift1         499\n",
       "43  index_rate_shift2         483\n",
       "44  index_rate_shift3         472\n",
       "12             max_30         472\n",
       "13             min_30         365\n",
       "38             high_t         333\n",
       "14            mean_30         315\n",
       "39              low_t         310\n",
       "9              min_15         309\n",
       "8              max_15         293\n",
       "6             mean_10         196\n",
       "4              max_10         172\n",
       "35       high_shift23         170\n",
       "5              min_10         159\n",
       "19          volume_60         157\n",
       "36        low_shift23         155\n",
       "10            mean_15         153\n",
       "22        mean_shift1         141\n",
       "23       high_shift21         140\n",
       "40             open_t         137\n",
       "24        low_shift21         135\n",
       "1               min_5         129\n",
       "21         low_shift1         126\n",
       "30        low_shift22         118\n",
       "0               max_5         112\n",
       "3            volume_5         110\n",
       "29       high_shift22         107\n",
       "20        high_shift1          97\n",
       "15          volume_30          92\n",
       "37       mean_shift23          90\n",
       "11          volume_15          88\n",
       "33         low_shift3          64\n",
       "7           volume_10          64\n",
       "31       mean_shift22          60\n",
       "25       mean_shift21          59\n",
       "26        high_shift2          56\n",
       "34        mean_shift3          50\n",
       "32        high_shift3          48\n",
       "2              mean_5          48\n",
       "27         low_shift2          39\n",
       "28        mean_shift2          34"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame({\n",
    "        'column': feature_col,\n",
    "        'importance': clf.feature_importance(),\n",
    "    }).sort_values(by='importance', ascending=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型保存"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.exists('model'):\n",
    "    os.path.mkdir('model')\n",
    "import pickle\n",
    "save_path = 'model/model.pickle'\n",
    "with open(save_path, 'wb') as f:\n",
    "    pickle.dump(clf, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 结果分析"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_df = df_min[test_idx][['datetime', 'ts_code', label_col, 'return', 'mean', 'datetime_open', 'datetime_close']]\n",
    "test_lgb = clf.predict(test, num_iteration=clf.best_iteration)\n",
    "test_df['pred'] = test_lgb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tp:1128, pp:2067, precision1: 0.5457\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.0007325614776948953"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_df['pred2'] = test_df['pred']>0.8\n",
    "tp = np.sum((test_df['pred2']==True)&(test_df['return']>0))\n",
    "pp = np.sum(test_df['pred2']==True)\n",
    "precison1 = tp/(pp+0.001)\n",
    "print('tp:%d, pp:%d, precision1: %.4f' % (tp, pp, precison1))\n",
    "test_df[test_df['pred2']==True]['return'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tn:62, pn:83, precision2: 0.7470\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "-0.0015078313189174908"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_df['pred2'] = test_df['pred']>0.2\n",
    "tn = np.sum((test_df['pred2']==False)&(test_df['return']<0))\n",
    "pn = np.sum(test_df['pred2']==False)\n",
    "precison2 = tn/(pn+0.001)\n",
    "print('tn:%d, pn:%d, precision2: %.4f' % (tn, pn, precison2))\n",
    "test_df[test_df['pred2']==False]['return'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 结果保存\n",
    "test_df.to_csv(os.path.join(base_path, 'result.csv'), index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
